import math
import numpy as np
import pandas as pd
from scipy import stats
from statsmodels.stats.weightstats import ztest
from threading import Thread, Lock
from common.database import OLAP

# Maximum number of rows sampled from the source table. begin_alg interpolates
# it into the SQL ``limit`` clause and stats_index compares the sample length
# against it. Kept as an int so it formats as "500000" rather than "500000.0".
LIMIT = 500_000


def thread_func(func, meta, meta_list, source_table, cols, method, condition, observed, results, lock):
    """Run one analysis and store its outcome in ``results[func]``.

    Intended as a ``Thread`` target: the return value is delivered through the
    shared ``results`` dict, which is only written while ``lock`` is held.

    :param func: analysis name. Any name containing ``"normal_test"`` runs the
        normality test; the other accepted names are ``"stats_index"``,
        ``"arg_test"``, ``"var_analysis"``, ``"reliability_analysis"`` and
        ``"corr_analysis"``.
    :param meta: full data matrix (used by reliability / correlation analysis).
    :param meta_list: per-column samples; for the normality test this is the
        single sample itself.
    :param source_table: table name forwarded to ``stats_index``.
    :param cols: selected columns; its length decides between the one-sample
        and multi-sample code paths.
    :param method: t-test flavour forwarded to ``arg_test``.
    :param condition: comparison direction forwarded to ``arg_test``.
    :param observed: reference value forwarded to ``arg_test``.
    :param results: shared output dict. ``results["levene_flag"]`` must already
        be present for a multi-column ``"arg_test"`` and
        ``results["normal_flag"]`` for ``"corr_analysis"``.
    :param lock: lock guarding writes to ``results``.
    :raises ValueError: if ``func`` is not a known analysis name.
    """
    n = len(cols)
    if func == "stats_index":
        result = stats_index(meta_list[0], source_table, cols)
    elif "normal_test" in func:
        result = normal_test(meta_list)
    elif func == "arg_test":
        # A single sample has no second variance to compare against.
        equal_var = 1 if n == 1 else results["levene_flag"]
        result = arg_test(meta_list, n, method, condition, observed, equal_var)
    elif func == "var_analysis":
        result = var_analysis(meta_list, n)
    elif func == "reliability_analysis":
        result = reliability_analysis(meta, n)
    elif func == "corr_analysis":
        # Pearson assumes normally distributed data; otherwise use the rank-based Spearman.
        corr_func = "Pearson" if results["normal_flag"] else "Spearman"
        result = corr_analysis(meta, corr_func)
    else:
        # Previously this fell through and failed with an opaque UnboundLocalError.
        raise ValueError(f"unknown analysis function: {func!r}")
    # The context manager releases the lock even if the assignment raises.
    with lock:
        results[func] = result


def diff_funcs_thread(funcs, meta, meta_list, source_table, cols, method, condition, observed, results, lock):
    """Run several different analyses over the same data, one thread each.

    Every name in ``funcs`` is handed to ``thread_func`` together with the
    shared arguments; the call returns once all threads have finished. The
    outcomes are delivered through ``results``.
    """
    shared_args = (meta, meta_list, source_table, cols, method, condition, observed, results, lock)
    workers = []
    for name in funcs:
        worker = Thread(target=thread_func, args=(name,) + shared_args)
        worker.start()
        workers.append(worker)
    # Block until every analysis has written its result.
    for worker in workers:
        worker.join()


def same_funcs_thread(func, meta_list, source_table, cols, method, condition, observed, results, lock):
    """Run the same analysis once per column, one thread each.

    For the i-th column ``thread_func`` receives ``meta_list[i]`` as its data
    and the column name as its ``cols`` argument; the result is stored in
    ``results`` under ``"<func>_<col>"``. Returns after all threads finish.
    """
    workers = []
    for idx, col in enumerate(cols):
        result_key = "_".join((func, col))
        worker = Thread(
            target=thread_func,
            args=(result_key, None, meta_list[idx], source_table, col, method, condition, observed, results, lock),
        )
        worker.start()
        workers.append(worker)
    # Block until every column has been processed.
    for worker in workers:
        worker.join()


def stats_index(meta, source_table, cols):
    """Compute descriptive statistics for a single numeric sample.

    :param meta: 1-D sample of numeric values.
    :param source_table: table queried for sum/min/max when the sample length
        equals ``LIMIT``.
    :param cols: column names; only ``cols[0]`` is used, in that query.
    :return: dict with keys ``sum``, ``mean``, ``median``, ``mode``, ``min``,
        ``max``, ``var``, ``std`` and ``CI`` (the 95% t-interval for the mean,
        as a two-element list).
    """
    num = len(meta)
    mean = np.mean(meta)
    std = np.std(meta)
    var = std ** 2

    if num == LIMIT:
        # The sample is exactly LIMIT rows long, so it was presumably truncated
        # by the sampling query; take sum/min/max from the database, whose query
        # carries no limit clause.
        col = cols[0]
        sql = "select sum({}) as sum, min({}) as min, max({}) as max from {}".format(col, col, col, source_table)
        total, minimum, maximum = OLAP.execute_query(sql)[0]
    else:
        total = np.sum(meta)
        minimum = np.min(meta)
        maximum = np.max(meta)

    # stats.mode returns a length-1 array on older SciPy and a scalar on SciPy
    # >= 1.11 for 1-D input; atleast_1d makes the indexing work on both.
    mode_value = np.atleast_1d(stats.mode(meta)[0])[0]

    return {
        "sum": total,
        "mean": mean,
        "median": np.median(meta),
        "mode": mode_value,
        "min": minimum,
        "max": maximum,
        "var": var,
        "std": std,
        "CI": list(stats.t.interval(0.95, num - 1, loc=mean, scale=stats.sem(meta))),
    }


def normal_test(meta):
    """Test a single sample for normality.

    The sample is declared non-normal (``p = 0``) outright when the absolute
    skewness or kurtosis reaches 1. Otherwise Shapiro-Wilk is used for fewer
    than 500 observations and Kolmogorov-Smirnov against a normal distribution
    with the sample's mean and standard deviation for larger samples. A NaN
    p-value is reported as 0.

    :param meta: 1-D sample of numeric values.
    :return: dict with keys ``skew``, ``kurtosis`` and ``p``.
    """
    size = len(meta)
    frame = pd.DataFrame(meta)
    result = {
        'skew': frame.skew().values[0],      # skewness
        'kurtosis': frame.kurt().values[0],  # kurtosis
    }

    # Strongly skewed or heavy/light-tailed data: skip the formal test.
    if abs(result['skew']) >= 1 or abs(result['kurtosis']) >= 1:
        result['p'] = 0
        return result

    if size < 500:
        _, p_value = stats.shapiro(meta)
    else:
        _, p_value = stats.kstest(meta, "norm", args=(np.mean(meta), np.std(meta)))
    result['p'] = 0 if math.isnan(p_value) else p_value
    return result


def norm_test(df, cols):
    """Return the columns of ``df`` that fail the normality test.

    Each column is tested with ``normal_test`` after dropping missing values;
    a column is reported when its p-value is below 0.05.

    :param df: table indexable by column name, whose columns support ``dropna()``.
    :param cols: column names to test.
    :return: list of column names, in the order given, that are not normal.
    """
    significance = 0.05
    return [col for col in cols if normal_test(df[col].dropna())["p"] < significance]


def get_norm_bins(data):
    """Build an 11-bin histogram of ``data`` plus a reference normal curve.

    :param data: 1-D sample of numeric values.
    :return: tuple ``(hist, bin, y)`` where ``hist`` holds the 11 bin counts,
        ``bin`` holds 11 evenly spaced positions from the lowest to the highest
        histogram edge, and ``y`` is the standard normal density evaluated at
        11 evenly spaced points on [-3, 3].
    """
    bin_count = 11
    counts, edges = np.histogram(data, bins=bin_count)
    # Replace the 12 bin edges by 11 evenly spaced positions over the same span.
    positions = np.linspace(edges.min(), edges.max(), bin_count)

    density = stats.norm.pdf(np.linspace(-3, 3, 11), 0, 1)
    return counts, positions, density


def arg_test(meta, n, method, condition, observed, equal_var):
    """Return the p-value of a mean test on one or two samples.

    A z-test is used when the first sample has at least 30 observations,
    otherwise a t-test.

    :param meta: sequence of samples; ``meta[0]`` always, ``meta[1]`` when
        ``n`` is not 1.
    :param n: number of samples; 1 selects the one-sample test against
        ``observed``.
    :param method: ``"ind"`` (independent) or ``"rel"`` (paired); only
        consulted by the two-sample t-test.
    :param condition: ``">"`` or ``"<"`` for a one-sided test, anything else
        for a two-sided test.
    :param observed: reference mean for the one-sample tests.
    :param equal_var: forwarded to ``stats.ttest_ind``.
    :return: the p-value, or NaN for a two-sample t-test with an unknown
        ``method``.

    NOTE(review): the two branches treat ``condition`` differently. The z-test
    maps ``">"`` to ``alternative='smaller'``, while the t-test turns ``">"``
    into ``p / 2`` regardless of the sign of the statistic. Confirm the
    intended convention before relying on the one-sided results.
    """
    first = meta[0]

    if len(first) >= 30:
        alternative = 'smaller' if condition == ">" else 'larger' if condition == "<" else 'two-sided'
        if n == 1:
            _, p_value = ztest(first, None, observed, alternative)
        else:
            _, p_value = ztest(first, meta[1], 0, alternative)
        return p_value

    if n == 1:
        _, p_value = stats.ttest_1samp(first, observed)
    elif method == "ind":
        _, p_value = stats.ttest_ind(first, meta[1], equal_var=equal_var)
    elif method == "rel":
        _, p_value = stats.ttest_rel(first, meta[1])
    else:
        return math.nan

    # Derive the one-sided value from the two-sided p-value.
    if condition == ">":
        return p_value / 2
    if condition == "<":
        return 1 - p_value / 2
    return p_value


def var_analysis(meta, n):
    """Return the p-value of a one-way ANOVA across the samples in ``meta``.

    :param meta: sequence of samples, one per group.
    :param n: number of groups (unused; kept for a uniform call signature).
    """
    anova = stats.f_oneway(*meta)
    return anova[1]


def reliability_analysis(meta, n):
    """Compute Cronbach's alpha for a rows-by-items score matrix.

    ``alpha = n / (n - 1) * (var(total) - sum(var(item))) / var(total)`` where
    ``total`` is the per-row sum across items.

    :param meta: 2-D array-like, one row per observation and one column per item.
    :param n: number of items (columns).
    :return: the alpha coefficient, or NaN when fewer than two items are given
        (the coefficient is undefined and ``n - 1`` would be zero).
    """
    if n < 2:
        # Avoid ZeroDivisionError; NaN is the "not computable" value used by arg_test as well.
        return math.nan

    total_variance = np.var(np.sum(meta, axis=1))
    item_variance_sum = np.sum(np.var(meta, axis=0))
    return (n / (n - 1)) * ((total_variance - item_variance_sum) / total_variance)


def corr_analysis(meta, corr_func):
    """Return the flattened correlation matrix of the columns of ``meta``.

    :param meta: 2-D numpy array, one column per variable.
    :param corr_func: ``"Pearson"`` for Pearson correlation; any other value
        selects Spearman rank correlation.
    :return: the correlation matrix flattened row by row into a list, with the
        diagonal forced to 1. When the library returns a single coefficient
        instead of a matrix, it is expanded to ``[1, r, r, 1]``.
    """
    if corr_func == "Pearson":
        matrix = np.corrcoef(meta.T)
    else:
        matrix, _ = stats.spearmanr(meta)

    # A bare coefficient (not an ndarray) stands for a 2x2 matrix.
    if type(matrix) is not np.ndarray:
        return [1, matrix, matrix, 1]

    diagonal = np.arange(meta.shape[1])
    matrix[diagonal, diagonal] = 1
    return list(matrix.flatten())


def levene_test(meta, n):
    """Check homogeneity of variance across samples with Levene's test.

    :param meta: sequence of samples, one per group.
    :param n: number of groups (unused; kept for a uniform call signature).
    :return: 1 when the p-value exceeds 0.05 (equal variances not rejected),
        otherwise 0.
    """
    threshold = 0.05
    p_value = stats.levene(*meta)[1]
    return 1 if p_value > threshold else 0


def set_decimal(results, decimal):
    """Round every number in ``results`` to ``decimal`` places, in place.

    Top-level values may be numbers, lists of numbers, or dicts whose values
    are numbers or lists of numbers. Lists and dicts are mutated in place;
    nothing is returned.

    :param results: dict of analysis results to round.
    :param decimal: number of decimal places passed to ``round``.
    """

    def round_items(values):
        # Overwrite each element with its rounded value, keeping the same list object.
        for position, item in enumerate(values):
            values[position] = round(item, decimal)

    for key, value in results.items():
        if type(value) is list:
            round_items(value)
        elif type(value) is dict:
            for inner_key, inner_value in value.items():
                if type(inner_value) is list:
                    round_items(inner_value)
                else:
                    value[inner_key] = round(inner_value, decimal)
        else:
            results[key] = round(value, decimal)


def begin_alg(source_table, func, feature_col, method, condition, observed, alpha):
    """Entry point: sample the selected columns and run the requested analyses.

    Up to ``LIMIT`` rows of the selected columns are fetched, NaNs are dropped
    per column, and normality (plus, for several columns, equal length and
    Levene's variance test) is checked first. The requested analysis or the
    default set is then run in threads and all numbers are rounded to 2 places.

    :param source_table: table name, interpolated directly into the SQL text.
    :param func: ``""`` to run the default set of analyses permitted by the
        pre-checks, ``"normal_test"`` to return only the normality results, or
        a single analysis name understood by ``thread_func``.
    :param feature_col: comma-separated column names; each is wrapped in double
        quotes before being interpolated into the SQL text.
    :param method: ``"ind"`` or ``"rel"``; selects the two-sample t-test flavour.
    :param condition: ``">"``, ``"<"`` or anything else (two-sided).
    :param observed: reference mean for the one-sample test.
    :param alpha: significance level for the normality pre-check.
    :return: ``str(results)`` on the normal path. When there is too little data
        (no rows, or no column with at least 3 values) the ``results`` dict
        itself is returned with ``num_flag`` set to 0, as before.

    NOTE(review): ``source_table`` and the column names are not escaped before
    being placed in the query; callers must pass trusted identifiers.
    """
    results = {}
    check_results = {}
    funcs = []

    cols = [f"\"{col}\"" for col in feature_col.split(",")]
    feature_col = ",".join(cols)
    n = len(cols)

    # Cast so the clause reads "limit 500000" even when LIMIT is stored as a float.
    sql = "select {} from {} limit {}".format(feature_col, source_table, int(LIMIT))
    df = OLAP.execute_query(sql)
    meta = np.asarray(df, float)

    # An empty result cannot be indexed as a 2-D matrix; report "not enough data".
    if meta.size == 0:
        results["num_flag"] = 0
        return results

    meta_list = [meta[:, i][~np.isnan(meta[:, i])] for i in range(n)]

    # num_flag: 0 = too little data, 1 = small sample, 2 = every column has >= 30 values.
    results["num_flag"] = 1
    nums = [len(data) for data in meta_list]
    if max(nums) < 3:
        results["num_flag"] = 0
        return results
    if min(nums) >= 30:
        results["num_flag"] = 2

    equal_flag = 1
    levene_flag = 0
    normal_result = None

    # Normality is tested per column; one failing column clears the flag.
    same_funcs_thread("normal_test", meta_list, "", cols, "", "", 0, check_results, Lock())
    normal_flag = 0 if any(result["p"] <= alpha for result in check_results.values()) else 1

    if n == 1:
        normal_result = list(check_results.values())[0]
        hist, positions, density = get_norm_bins(meta_list[0])
        results["hist"] = list(hist)
        results["bin"] = list(positions)
        results["normal"] = list(density)
    else:
        # All columns must keep the same number of non-NaN values.
        equal_flag = int(len(set(nums)) == 1)
        results["equal_flag"] = equal_flag

        levene_flag = levene_test(meta_list, n)
        results["levene_flag"] = levene_flag
    results["normal_flag"] = normal_flag

    if func == "":
        if n == 1:
            funcs.append("stats_index")
            results["normal_test"] = normal_result
        if n < 3 and normal_flag and (method == 'rel' and equal_flag or method == 'ind'):
            funcs.append("arg_test")
        if n > 1 and normal_flag and levene_flag:
            funcs.append("var_analysis")
        if n > 1 and equal_flag:
            funcs.append("reliability_analysis")
            funcs.append("corr_analysis")
        diff_funcs_thread(funcs, meta, meta_list, source_table, cols, method, condition, observed, results, Lock())
    elif func == "normal_test":
        if n == 1:
            results["normal_test"] = normal_result
        else:
            # Several columns: publish each column's result under its
            # "normal_test_<col>" key instead of failing on an unset variable.
            results.update(check_results)
    else:
        diff_funcs_thread([func], meta, meta_list, source_table, cols, method, condition, observed, results, Lock())

    set_decimal(results, 2)
    # Added after rounding because set_decimal only handles numeric values.
    results['feature_col'] = feature_col
    return str(results)
